import json
import os
import traceback

from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import sessionmaker

from Spider4Mirror.util.logger import Logger


class MysqlUtil:
    """Helper bundling a SQLAlchemy engine, session factory, inspector and logger.

    The connection URL is taken from the ``MYSQL_ADDRESS`` environment variable
    when it is set to a non-empty value. Otherwise it is built from the
    ``mysql`` section of ``config/config.json``, looked up one directory above
    the directory that contains this module.
    """

    def __init__(self):
        """Resolve the connection URL and build the engine-related attributes.

        :raises OSError: if the config file is needed but cannot be opened
        :raises KeyError: if the config file lacks the ``mysql`` section or one
            of the ``user``/``password``/``host``/``port``/``database`` keys
        """
        # A URL supplied through the environment takes precedence.
        mysql_address = os.getenv("MYSQL_ADDRESS")
        if not mysql_address:
            # Fall back to the config file. The path is made absolute first so
            # the lookup does not depend on the current working directory when
            # __file__ happens to be relative.
            dir_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            config_path = os.path.join(dir_path, "config", "config.json")
            # JSON is UTF-8 by specification; do not rely on the locale default.
            with open(config_path, "r", encoding="utf-8") as config_file:
                config = json.load(config_file)["mysql"]
            # NOTE(review): values are interpolated verbatim, so a user name or
            # password containing characters such as "@" or "/" presumably has
            # to be URL-escaped in the config file - confirm.
            mysql_address = "mysql+pymysql://{user}:{password}@{host}:{port}/{database}".format(**config)
        # Engine with a pool of up to 100 connections, recycled after one hour.
        self.engine = create_engine(mysql_address, pool_size=100, pool_recycle=60 * 60)
        # Session factory; every helper method opens and closes its own session.
        # expire_on_commit=False keeps loaded attribute values on the objects
        # after commit instead of expiring them.
        self.db_session = sessionmaker(bind=self.engine, expire_on_commit=False)
        self.db_inspect = inspect(self.engine)
        self.log = Logger().logger

    def insert_many(self, objs):
        """Insert several model objects in a single transaction.

        :param objs: iterable of model instances to add
        :raises Exception: any error raised while adding or committing is
            logged, the transaction is rolled back, and the error is re-raised
        """
        session = self.db_session()
        try:
            session.add_all(objs)
            session.commit()
            self.log.info("写入成功")
        except Exception as e:
            self.log.error("写入失败， 错误原因{0}".format(e))
            session.rollback()
            # Bare raise re-raises the active exception with its traceback.
            raise
        finally:
            session.close()

    def judge_table_is_exist(self, table_name):
        """Return ``True`` if the inspector reports that ``table_name`` exists.

        NOTE(review): the inspector is created once in ``__init__``; depending
        on the SQLAlchemy version it may cache reflection results, so tables
        created or dropped later might not be reflected - confirm.

        :param table_name: name of the table to look for
        :return: ``True`` when the table exists, otherwise ``False``
        """
        return bool(self.db_inspect.has_table(table_name))

    def query_one(self, query_param, filter_params, order_params):
        """Return the first row matching the given filters and ordering.

        :param query_param: what to select; either a model class or a list of
            model columns, e.g. ``model`` or ``[model.id, model.name]``
        :param filter_params: filter criteria, e.g. ``[model.id == 1]``
        :param order_params: ordering criteria
        :return: the first result, or ``None`` when nothing matches or when the
            query raised an error (the error is logged, not propagated)
        """
        session = self.db_session()
        # Normalise to a list so a single model and a column list share a path.
        entities = query_param if isinstance(query_param, list) else [query_param]
        try:
            return (
                session.query(*entities)
                .filter(*filter_params)
                .order_by(*order_params)
                .first()
            )
        except Exception as error_info:
            self.log.error("错误信息：{0}".format(traceback.format_exc()))
            self.log.error("查询出错信息：{0}".format(error_info))
            return None
        finally:
            # Always release the session, including when a non-Exception
            # error (e.g. KeyboardInterrupt) propagates out of the query.
            session.close()


if __name__ == "__main__":
    # Running the module directly only constructs the helper, which resolves
    # the connection settings and builds the engine-related attributes.
    MysqlUtil()
